from typing import Dict, Any
from .hotpot import HotpotDataset
from .sst2 import SST2Dataset
from .tydi import TydiDataset
from .cnn import CNNDataset

class EvalDataset:
    """
    Unified dataset loader that standardizes output format across datasets.

    Supported datasets: 'hotpot', 'sst2', 'tydi', 'cnn'
    Output format: {"segment": list, "system": str, "question": str}

    The segment is returned as a list of strings that can be joined with
    join_segments() according to the conventions of the selected dataset.
    The system message is dataset-specific.

    All per-dataset conventions live in the ``_SPECS`` table so that
    ``__getitem__`` and ``join_segments`` cannot drift apart.
    """

    # System prompt shared by the two context-based QA datasets.
    _QA_SYSTEM = "You are a helpful assistant that can answer questions concisely based only on the context provided."

    # Separator used by join_segments() when the dataset name is not in _SPECS.
    _DEFAULT_JOINER = " "

    # dataset name -> conventions for that dataset:
    #   segment_key: key of the underlying item that holds the list of segments
    #   question:    fixed question text, or None to read item["question"]
    #   system:      system prompt returned with every item
    #   joiner:      separator used by join_segments()
    # NOTE: prompt strings are reproduced verbatim (including spacing) because
    # they are sent to the model and changing them changes evaluation output.
    _SPECS = {
        "hotpot": {
            "segment_key": "sentences",
            "question": None,
            "system": _QA_SYSTEM,
            "joiner": " ",
        },
        "sst2": {
            "segment_key": "words",
            "question": "What is the sentiment of the provided sentence? Response concisely.",
            "system": "You are a helpful assistant that can analyze the sentiment of a given text. Answer with only a single word,  'positive' 'negative' or 'neutral'.",
            "joiner": " ",
        },
        "tydi": {
            "segment_key": "paragraphs",
            "question": None,
            "system": _QA_SYSTEM,
            "joiner": "\n\n",
        },
        "cnn": {
            "segment_key": "sentences",
            "question": "Summarize the following news article in a few sentences.",
            "system": "You are a helpful assistant that can summarize news articles concisely based on the provided content.",
            "joiner": " ",
        },
    }

    def __init__(self, dataset_name: str, split: str = "test"):
        """
        Initialize the evaluation dataset.

        Args:
            dataset_name (str): Name of the dataset ('hotpot', 'sst2', 'tydi',
                'cnn'); matched case-insensitively.
            split (str): Dataset split to load (default: 'test'). Ignored for
                'tydi', which is always loaded with split='train'.

        Raises:
            ValueError: If dataset_name is not one of the supported datasets.
        """
        self.dataset_name = dataset_name.lower()
        self.split = split

        # Loaders are wrapped in lambdas so that only the selected dataset
        # class is instantiated.
        loaders = {
            "hotpot": lambda: HotpotDataset(split=split),
            "sst2": lambda: SST2Dataset(split=split),
            # TydiDataset only supports 'train' split, so we use 'train' regardless
            "tydi": lambda: TydiDataset(split="train"),
            "cnn": lambda: CNNDataset(split=split),
        }
        loader = loaders.get(self.dataset_name)
        if loader is None:
            raise ValueError(f"Unsupported dataset: {dataset_name}. Supported datasets: 'hotpot', 'sst2', 'tydi', 'cnn'")
        self.dataset = loader()

    def __len__(self) -> int:
        """Return the number of items in the underlying dataset."""
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Get a single item from the dataset in standardized format.

        Args:
            idx (int): Index of the item

        Returns:
            Dict: {"segment": list, "system": str, "question": str}, where
                segment is the item's list of strings for this dataset
                (see join_segments() for joining them).

        Raises:
            ValueError: If self.dataset_name is not a supported dataset (for
                example because the attribute was reassigned after
                construction).
        """
        spec = self._SPECS.get(self.dataset_name)
        if spec is None:
            # Fail with the same error type and wording as __init__ instead of
            # falling through with unbound local variables.
            raise ValueError(f"Unsupported dataset: {self.dataset_name}. Supported datasets: 'hotpot', 'sst2', 'tydi', 'cnn'")

        item = self.dataset[idx]
        segment = item[spec["segment_key"]]

        # Datasets without a fixed question carry their own per-item question.
        fixed_question = spec["question"]
        question = item["question"] if fixed_question is None else fixed_question

        return {"segment": segment, "system": spec["system"], "question": question}

    def join_segments(self, segments: list) -> str:
        """
        Join segments according to this dataset's specific conventions.

        'tydi' paragraphs are joined with a blank line between them; every
        other dataset (and any unrecognized dataset name) is joined with a
        single space.

        Args:
            segments (list): List of segment strings

        Returns:
            str: Joined segments according to dataset conventions
        """
        spec = self._SPECS.get(self.dataset_name)
        joiner = self._DEFAULT_JOINER if spec is None else spec["joiner"]
        return joiner.join(segments)
